{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Human numbers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.text import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs=64"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[PosixPath('/data1/jhoward/git/course-v3/nbs/dl1/data/human_numbers/valid.txt'),\n",
       " PosixPath('/data1/jhoward/git/course-v3/nbs/dl1/data/human_numbers/train.txt'),\n",
       " PosixPath('/data1/jhoward/git/course-v3/nbs/dl1/data/human_numbers/models')]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "path = untar_data(URLs.HUMAN_NUMBERS)\n",
    "path.ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readnums(d):\n",
    "    \"Read `path/d` and return a one-item list holding its stripped lines joined by ', '.\"\n",
    "    # `with` closes the file handle as soon as the lines have been consumed.\n",
    "    with open(path/d) as f: return [', '.join(o.strip() for o in f)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_txt = readnums('train.txt'); train_txt[0][:80]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid_txt = readnums('valid.txt'); valid_txt[0][-80:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = TextList(train_txt, path=path)\n",
    "valid = TextList(valid_txt, path=path)\n",
    "\n",
    "src = ItemLists(path=path, train=train, valid=valid).label_for_lm()\n",
    "data = src.databunch(bs=bs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train[0].text[:80]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13017"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(data.valid_ds[0][0].data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(70, 3)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.bptt, len(data.valid_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.905580357142857"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# tokens / bptt / bs, derived from `data` so it stays correct if the dataset or bptt changes\n",
    "len(data.valid_ds[0][0].data)/data.bptt/bs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "it = iter(data.valid_dl)\n",
    "x1,y1 = next(it)\n",
    "x2,y2 = next(it)\n",
    "x3,y3 = next(it)\n",
    "it.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "12928"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.numel()+x2.numel()+x3.numel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([95, 64]), torch.Size([95, 64]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.shape,y1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([67, 64]), torch.Size([67, 64]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x2.shape,y2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2, 18, 10, 11,  8, 18, 10, 12,  8, 18, 10, 13,  8, 18, 10, 14,  8, 18,\n",
       "        10, 15,  8, 18, 10, 16,  8, 18, 10, 17,  8, 18, 10, 18,  8, 18, 10, 19,\n",
       "         8, 18, 10, 28,  8, 18, 10, 29,  8, 18, 10, 30,  8, 18, 10, 31,  8, 18,\n",
       "        10, 32,  8, 18, 10, 33,  8, 18, 10, 34,  8, 18, 10, 35,  8, 18, 10, 36,\n",
       "         8, 18, 10, 37,  8, 18, 10, 20,  8, 18, 10, 20, 11,  8, 18, 10, 20, 12,\n",
       "         8, 18, 10, 20, 13], device='cuda:0')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1[:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([18, 10, 11,  8, 18, 10, 12,  8, 18, 10, 13,  8, 18, 10, 14,  8, 18, 10,\n",
       "        15,  8, 18, 10, 16,  8, 18, 10, 17,  8, 18, 10, 18,  8, 18, 10, 19,  8,\n",
       "        18, 10, 28,  8, 18, 10, 29,  8, 18, 10, 30,  8, 18, 10, 31,  8, 18, 10,\n",
       "        32,  8, 18, 10, 33,  8, 18, 10, 34,  8, 18, 10, 35,  8, 18, 10, 36,  8,\n",
       "        18, 10, 37,  8, 18, 10, 20,  8, 18, 10, 20, 11,  8, 18, 10, 20, 12,  8,\n",
       "        18, 10, 20, 13,  8], device='cuda:0')"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y1[:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v = data.valid_ds.vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x1[:,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three ,'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(y1[:,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "', eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two , eight thousand thirty three , eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x2[:,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'thirty seven , eight thousand thirty eight , eight thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x3[:,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x1[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand seventy four , eight thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x2[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "', eight thousand seventy nine , eight thousand eighty , eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six ,'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x3[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'nine hundred ninety , nine thousand nine hundred ninety one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v.textify(x3[:,-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>  <col width='5%'>  <col width='95%'>  <tr>\n",
       "    <th>idx</th>\n",
       "    <th>text</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>0</th>\n",
       "    <th>xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen ,</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>, eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>thousand eighty seven , eight thousand eighty eight , eight thousand eighty nine , eight thousand ninety , eight thousand ninety one , eight thousand ninety two , eight thousand ninety three , eight thousand ninety four , eight thousand ninety five , eight thousand ninety six , eight thousand ninety seven , eight thousand ninety eight , eight thousand ninety nine , eight thousand one hundred , eight</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>thousand one hundred twenty three , eight thousand one hundred twenty four , eight thousand one hundred twenty five , eight thousand one hundred twenty six , eight thousand one hundred twenty seven , eight thousand one hundred twenty eight , eight thousand one hundred twenty nine , eight thousand one hundred thirty , eight thousand one hundred thirty one , eight thousand one hundred thirty two , eight</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>fifty two , eight thousand one hundred fifty three , eight thousand one hundred fifty four , eight thousand one hundred fifty five , eight thousand one hundred fifty six , eight thousand one hundred fifty seven , eight thousand one hundred fifty eight , eight thousand one hundred fifty nine , eight thousand one hundred sixty , eight thousand one hundred sixty one , eight thousand one hundred</th>\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "data.show_batch(ds_type=DatasetType.Valid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Single fully connected model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = src.databunch(bs=bs, bptt=3, max_len=0, p_bptt=1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([3, 64]), torch.Size([3, 64]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x,y = data.one_batch()\n",
    "x.shape,y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "38"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nv = len(v.itos); nv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nh=64\n",
    "class Model1(nn.Module):\n",
    "    \"Fully connected net that folds every row of `x` into one hidden state and predicts from the last state.\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)  # token id -> hidden\n",
    "        self.h_h = nn.Linear(nh,nh)     # hidden -> hidden\n",
    "        self.h_o = nn.Linear(nh,nv)     # hidden -> vocab activations\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # One state row per column of `x`; iterating `x` walks its first dimension.\n",
    "        state = torch.zeros(x.shape[1], nh, device=x.device)\n",
    "        for tok in x:\n",
    "            state = self.bn(F.relu(self.h_h(state + self.i_h(tok))))\n",
    "        return self.h_o(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss4(input,target):\n",
    "    \"Cross-entropy of `input` against the last row of `target` only.\"\n",
    "    last = target[-1]\n",
    "    return F.cross_entropy(input, last)\n",
    "\n",
    "def acc4(input,target):\n",
    "    \"Accuracy of `input` against the last row of `target` only.\"\n",
    "    last = target[-1]\n",
    "    return accuracy(input, last)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, Model1(), loss_func=loss4, metrics=acc4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:06 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>acc4</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>3.576015</th>\n",
       "    <th>3.581396</th>\n",
       "    <th>0.085860</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>3.084124</th>\n",
       "    <th>3.054432</th>\n",
       "    <th>0.280863</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>2.517673</th>\n",
       "    <th>2.579382</th>\n",
       "    <th>0.383973</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>2.187254</th>\n",
       "    <th>2.341024</th>\n",
       "    <th>0.391631</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>2.056659</th>\n",
       "    <th>2.261394</th>\n",
       "    <th>0.392559</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>2.029062</th>\n",
       "    <th>2.249638</th>\n",
       "    <th>0.392791</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(6, 1e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi fully connected model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = src.databunch(bs=bs, bptt=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([45, 64]), torch.Size([45, 64]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x,y = data.one_batch()\n",
    "x.shape,y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model2(nn.Module):\n",
    "    \"Fully connected net that emits one prediction after every row of `x`, stacked along a new first dim.\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)  # token id -> hidden\n",
    "        self.h_h = nn.Linear(nh,nh)     # hidden -> hidden\n",
    "        self.h_o = nn.Linear(nh,nv)     # hidden -> vocab activations\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # The state starts at zero on every call; nothing is carried between batches.\n",
    "        state = torch.zeros(x.shape[1], nh, device=x.device)\n",
    "        outs = []\n",
    "        for tok in x:\n",
    "            state = self.bn(F.relu(self.h_h(state + self.i_h(tok))))\n",
    "            outs.append(self.h_o(state))\n",
    "        return torch.stack(outs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, Model2(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:06 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>3.704546</th>\n",
       "    <th>3.669295</th>\n",
       "    <th>0.023670</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>3.606465</th>\n",
       "    <th>3.551982</th>\n",
       "    <th>0.080213</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>3.485057</th>\n",
       "    <th>3.433933</th>\n",
       "    <th>0.156405</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>3.360244</th>\n",
       "    <th>3.323397</th>\n",
       "    <th>0.293704</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>3.245313</th>\n",
       "    <th>3.238923</th>\n",
       "    <th>0.350156</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>3.149909</th>\n",
       "    <th>3.181015</th>\n",
       "    <th>0.393054</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>7</th>\n",
       "    <th>3.075431</th>\n",
       "    <th>3.141364</th>\n",
       "    <th>0.404316</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>8</th>\n",
       "    <th>3.022162</th>\n",
       "    <th>3.121332</th>\n",
       "    <th>0.404548</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>9</th>\n",
       "    <th>2.989504</th>\n",
       "    <th>3.118630</th>\n",
       "    <th>0.408416</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>10</th>\n",
       "    <th>2.972034</th>\n",
       "    <th>3.114454</th>\n",
       "    <th>0.408029</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(10, 1e-4, pct_start=0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Maintain state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model3(nn.Module):\n",
    "    \"Fully connected net that predicts after every row of `x` and keeps its hidden state between `forward` calls.\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.h_h = nn.Linear(nh,nh)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "        # Allocated on first use in `forward`, so building the model needs neither a leftover global `x` nor a GPU.\n",
    "        self.h = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        n = x.shape[1]\n",
    "        # (Re)create the state when it is missing or no longer matches the input's batch size or device.\n",
    "        if self.h is None or self.h.shape[0] != n or self.h.device != x.device:\n",
    "            self.h = torch.zeros(n, nh, device=x.device)\n",
    "        res = []\n",
    "        h = self.h\n",
    "        for xi in x:\n",
    "            h = h + self.i_h(xi)\n",
    "            h = F.relu(self.h_h(h))\n",
    "            res.append(h)\n",
    "        # Keep the values for the next call but cut the autograd history at the batch boundary.\n",
    "        self.h = h.detach()\n",
    "        res = torch.stack(res)\n",
    "        # BatchNorm1d treats dim 1 as the feature dim, so flatten the two leading dims to put `nh` there.\n",
    "        res = self.bn(res.view(-1, nh)).view_as(res)\n",
    "        return self.h_o(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, Model3(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:09 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>3.574752</th>\n",
       "    <th>3.487574</th>\n",
       "    <th>0.096380</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>3.218008</th>\n",
       "    <th>2.850531</th>\n",
       "    <th>0.269261</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>2.640497</th>\n",
       "    <th>2.155723</th>\n",
       "    <th>0.465269</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>2.107916</th>\n",
       "    <th>1.925786</th>\n",
       "    <th>0.372293</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>1.743533</th>\n",
       "    <th>1.977690</th>\n",
       "    <th>0.366027</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>1.461914</th>\n",
       "    <th>1.873596</th>\n",
       "    <th>0.417002</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>7</th>\n",
       "    <th>1.239240</th>\n",
       "    <th>1.885451</th>\n",
       "    <th>0.467923</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>8</th>\n",
       "    <th>1.069399</th>\n",
       "    <th>1.886692</th>\n",
       "    <th>0.476949</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>9</th>\n",
       "    <th>0.943912</th>\n",
       "    <th>1.961975</th>\n",
       "    <th>0.473159</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>10</th>\n",
       "    <th>0.827006</th>\n",
       "    <th>1.950261</th>\n",
       "    <th>0.510674</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>11</th>\n",
       "    <th>0.733765</th>\n",
       "    <th>2.038847</th>\n",
       "    <th>0.520471</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>12</th>\n",
       "    <th>0.658219</th>\n",
       "    <th>1.988615</th>\n",
       "    <th>0.524675</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>13</th>\n",
       "    <th>0.605873</th>\n",
       "    <th>1.973706</th>\n",
       "    <th>0.550201</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>14</th>\n",
       "    <th>0.551433</th>\n",
       "    <th>2.027293</th>\n",
       "    <th>0.540130</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>15</th>\n",
       "    <th>0.519542</th>\n",
       "    <th>2.041594</th>\n",
       "    <th>0.531250</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>16</th>\n",
       "    <th>0.497289</th>\n",
       "    <th>2.111806</th>\n",
       "    <th>0.537891</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>17</th>\n",
       "    <th>0.476476</th>\n",
       "    <th>2.104390</th>\n",
       "    <th>0.534837</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>18</th>\n",
       "    <th>0.458751</th>\n",
       "    <th>2.112886</th>\n",
       "    <th>0.534242</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>19</th>\n",
       "    <th>0.463085</th>\n",
       "    <th>2.067193</th>\n",
       "    <th>0.543007</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>20</th>\n",
       "    <th>0.452624</th>\n",
       "    <th>2.089713</th>\n",
       "    <th>0.542400</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(20, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## nn.RNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model4(nn.Module):\n",
    "    \"`nn.RNN`-based net that predicts after every row of `x` and keeps its hidden state between `forward` calls.\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.rnn = nn.RNN(nh,nh)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "        # Allocated on first use in `forward`, so building the model needs neither a leftover global `x` nor a GPU.\n",
    "        self.h = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        n = x.shape[1]\n",
    "        # nn.RNN expects h_0 shaped (num_layers, batch, hidden); rebuild it if batch size or device changed.\n",
    "        if self.h is None or self.h.shape[1] != n or self.h.device != x.device:\n",
    "            self.h = torch.zeros(1, n, nh, device=x.device)\n",
    "        res,h = self.rnn(self.i_h(x), self.h)\n",
    "        # Keep the values for the next call but cut the autograd history at the batch boundary.\n",
    "        self.h = h.detach()\n",
    "        # BatchNorm1d treats dim 1 as the feature dim, so flatten the two leading dims to put `nh` there.\n",
    "        res = self.bn(res.reshape(-1, nh)).view_as(res)\n",
    "        return self.h_o(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(data, Model4(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:04 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>3.502738</th>\n",
       "    <th>3.372026</th>\n",
       "    <th>0.252707</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>3.092665</th>\n",
       "    <th>2.714043</th>\n",
       "    <th>0.457998</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>2.538071</th>\n",
       "    <th>2.189881</th>\n",
       "    <th>0.467048</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>2.057624</th>\n",
       "    <th>1.946149</th>\n",
       "    <th>0.451719</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>1.697061</th>\n",
       "    <th>1.923625</th>\n",
       "    <th>0.466471</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>1.424962</th>\n",
       "    <th>1.916880</th>\n",
       "    <th>0.487856</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>7</th>\n",
       "    <th>1.221850</th>\n",
       "    <th>2.029671</th>\n",
       "    <th>0.507735</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>8</th>\n",
       "    <th>1.063150</th>\n",
       "    <th>1.911920</th>\n",
       "    <th>0.523128</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>9</th>\n",
       "    <th>0.926894</th>\n",
       "    <th>1.882562</th>\n",
       "    <th>0.541045</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>10</th>\n",
       "    <th>0.801033</th>\n",
       "    <th>1.920954</th>\n",
       "    <th>0.541228</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>11</th>\n",
       "    <th>0.719016</th>\n",
       "    <th>1.874411</th>\n",
       "    <th>0.553914</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>12</th>\n",
       "    <th>0.625660</th>\n",
       "    <th>1.983035</th>\n",
       "    <th>0.558014</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>13</th>\n",
       "    <th>0.574975</th>\n",
       "    <th>1.900878</th>\n",
       "    <th>0.560721</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>14</th>\n",
       "    <th>0.505169</th>\n",
       "    <th>1.893559</th>\n",
       "    <th>0.571627</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>15</th>\n",
       "    <th>0.468173</th>\n",
       "    <th>1.882392</th>\n",
       "    <th>0.576869</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>16</th>\n",
       "    <th>0.430182</th>\n",
       "    <th>1.848286</th>\n",
       "    <th>0.574489</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>17</th>\n",
       "    <th>0.400253</th>\n",
       "    <th>1.899022</th>\n",
       "    <th>0.580929</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>18</th>\n",
       "    <th>0.381917</th>\n",
       "    <th>1.907899</th>\n",
       "    <th>0.579285</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>19</th>\n",
       "    <th>0.365580</th>\n",
       "    <th>1.913658</th>\n",
       "    <th>0.578666</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>20</th>\n",
       "    <th>0.367523</th>\n",
       "    <th>1.918424</th>\n",
       "    <th>0.577197</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "learn.fit_one_cycle(20, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2-layer GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model5(nn.Module):\n",
    "    \"\"\"Two-layer GRU model that carries its hidden state from one batch to the next.\n",
    "\n",
    "    Reads the notebook globals `nv`, `nh` and `bs`: `nv` sizes the embedding table\n",
    "    and the output layer, `nh` is the embedding/hidden width, and `bs` is the batch\n",
    "    dimension of the stored hidden state.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.i_h = nn.Embedding(nv,nh)\n",
    "        self.rnn = nn.GRU(nh,nh,2)\n",
    "        self.h_o = nn.Linear(nh,nv)\n",
    "        self.bn = nn.BatchNorm1d(nh)\n",
    "        # One slice per GRU layer. Created on the CPU and moved lazily in `forward`,\n",
    "        # so constructing the model does not require a CUDA device.\n",
    "        self.h = torch.zeros(2, bs, nh)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Embed `x`, run the GRU from the stored hidden state, and return `h_o(bn(res))`.\"\"\"\n",
    "        # `.to` returns the same tensor once the state already lives on `x`'s device.\n",
    "        res,h = self.rnn(self.i_h(x), self.h.to(x.device))\n",
    "        # Keep the final state for the next call, detached from this batch's graph.\n",
    "        self.h = h.detach()\n",
    "        # NOTE(review): the GRU is not `batch_first`, so dim 1 of `res` is the batch axis,\n",
    "        # while BatchNorm1d(nh) treats dim 1 of a 3-D input as the feature axis. This only\n",
    "        # lines up when bs == nh - confirm which axis is meant to be normalised.\n",
    "        return self.h_o(self.bn(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebind `learn` to a Learner built from `data` and a fresh Model5, reporting accuracy.\n",
    "learn = Learner(data, Model5(), metrics=accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Total time: 00:02 <p><table style='width:300px; margin-bottom:10px'>\n",
       "  <tr>\n",
       "    <th>epoch</th>\n",
       "    <th>train_loss</th>\n",
       "    <th>valid_loss</th>\n",
       "    <th>accuracy</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>1</th>\n",
       "    <th>3.010502</th>\n",
       "    <th>2.602906</th>\n",
       "    <th>0.428063</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>2</th>\n",
       "    <th>2.087371</th>\n",
       "    <th>1.765773</th>\n",
       "    <th>0.532410</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>3</th>\n",
       "    <th>1.311803</th>\n",
       "    <th>1.571677</th>\n",
       "    <th>0.643796</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>4</th>\n",
       "    <th>0.675637</th>\n",
       "    <th>1.594766</th>\n",
       "    <th>0.730275</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>5</th>\n",
       "    <th>0.363373</th>\n",
       "    <th>1.432574</th>\n",
       "    <th>0.760873</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>6</th>\n",
       "    <th>0.188198</th>\n",
       "    <th>1.704319</th>\n",
       "    <th>0.762454</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>7</th>\n",
       "    <th>0.108004</th>\n",
       "    <th>1.716183</th>\n",
       "    <th>0.755837</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>8</th>\n",
       "    <th>0.064206</th>\n",
       "    <th>1.510942</th>\n",
       "    <th>0.763404</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>9</th>\n",
       "    <th>0.040955</th>\n",
       "    <th>1.633394</th>\n",
       "    <th>0.754179</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "    <th>10</th>\n",
       "    <th>0.034651</th>\n",
       "    <th>1.621733</th>\n",
       "    <th>0.755460</th>\n",
       "  </tr>\n",
       "  <tr>\n",
       "\n",
       "  </tr>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Train for 10 epochs with the one-cycle schedule; 1e-2 is passed as the learning rate.\n",
    "learn.fit_one_cycle(10, 1e-2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
